"""
Import needed packages
"""
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from pandas.tseries.offsets import MonthEnd


"""
Define supporting functions
"""


def col2monthend(df: pd.DataFrame, col: str, d_format: str = '%Y-%m-%d') -> pd.DataFrame:
    """
    Convert a string/int/float date column to datetimes snapped to the month end.

    The column is overwritten in place on ``df``. Values that cannot be parsed with
    ``d_format`` (including missing values) become ``NaT``.

    :param df: Input DataFrame (its ``col`` column is replaced in place)
    :param col: Column name to be converted
    :param d_format: Format of the date to be converted
    :return: the same DataFrame object, with ``col`` holding month-end timestamps
    """
    raw = df[col]
    if raw.dtype == 'float64':
        # Floats such as 20200115.0 must lose the trailing ".0" before parsing.
        # Missing entries are tracked with a mask instead of a magic sentinel so
        # that no legitimate value can be mistaken for "missing".
        missing = raw.isna()
        text = raw.fillna(0).astype(int).astype(str).where(~missing)
    else:
        text = raw.astype(str)
    parsed = pd.to_datetime(text, format=d_format, errors='coerce')
    # Drop any time-of-day component, then roll forward to the last day of the
    # month (dates already on a month end are left unchanged by MonthEnd(0)).
    # This avoids casting to 'datetime64[M]', which pandas >= 2.0 rejects.
    df[col] = parsed.dt.normalize() + MonthEnd(0)
    return df


def jobs_grp_cumulate(grp, nn, jt):
    """
    Turn new hires into a running stock of jobs via the recurrence
        x_it = (1 - s(x)_t) * x_i(t-1) + h_it
    where:  - x can be [L, l, lambda] i.e. the total jobs for AI, OT and DM jobs
            - s(x)_t is the type-specific monthly separation rate, read from column 's_<jt>'
            - h_it is the product of columns 'is<jt>' and 'h_o_<jt>'
    The recurrence starts at the row following the first row whose '<nn>_<jt>' value differs
    from zero. If the column sums to less than 1 the group is returned untouched (and its
    index is NOT reset).
    :param grp: dataframe per employer group
    :param nn: beginning of the name of the variable to be constructed
    :param jt: type of jobs to be considered, can be ['AI', 'OldTech', 'DataMgmt']
    :return: group specific DataFrame with the cumulated variables
    """
    stock_col = nn + '_' + jt
    if not (grp[stock_col].sum() >= 1):
        return grp

    separation_col = 's_' + jt
    flag_col = 'is' + jt
    hires_col = 'h_o_' + jt

    # Work on a positionally indexed copy so that "previous row" is simply row - 1.
    grp = grp.reset_index(drop=True)
    first_active = grp.index[(grp[stock_col] != 0).to_numpy()][0]
    for row in range(first_active + 1, len(grp)):
        survivors = (1 - grp.loc[row, separation_col]) * grp.loc[row - 1, stock_col]
        new_hires = grp.loc[row, flag_col] * grp.loc[row, hires_col]
        grp.loc[row, stock_col] = survivors + new_hires
    return grp


def outcome_merge_labor(df, sample):
    """
    Merge the labor information to outcome of characteristics.

    Rows are matched on ('emp', 'yearmonth') with pandas' default (inner) merge, so only
    employer-months present in both inputs are kept. Note that ``df['yearmonth']`` is
    converted to datetime in place on the caller's DataFrame; ``sample`` is not modified.

    :param df: the characteristics under optimal parameter combination
    :param sample: sample employment information for analysis
    :return: merged dataframe of labor information and characteristics, sorted by
             employer and month, with an extra column 'l_ai_ot' = l_ai + l_ot
    """
    labor_columns = ['emp', 'yearmonth', 'l_ai', 'l_ot', 'l_dm', 'GS1M', 'w_AI', 'w_OT', 'w_DM']
    # Explicit copy: adding a column to a bare column-selection of `sample` is a chained
    # assignment (SettingWithCopyWarning) and is not guaranteed to behave consistently.
    labor = sample[labor_columns].copy()
    labor['l_ai_ot'] = labor['l_ai'] + labor['l_ot']
    df['yearmonth'] = pd.to_datetime(df['yearmonth'])
    merged = df.merge(labor, on=['emp', 'yearmonth'])
    return merged.sort_values(['emp', 'yearmonth'])


def get_params(pps):
    """
    Pull the individual model parameters out of a parameter dictionary
    :param pps: a dictionary of optimal parameter combination; must contain the keys
                'alpha', 'gamma', 'delta', 'phi' and 'd_0_av'
    :return: tuple (alpha, gamma, delta, phi, d_0_av)
    """
    ordered_keys = ('alpha', 'gamma', 'delta', 'phi', 'd_0_av')
    return tuple(pps[key] for key in ordered_keys)


def plot_productivity(df, path, plots_colors, font=18):
    """
    Plot the monthly average of the AI and OT productivity parameters over time
    :param df: (pandas DataFrame) the characteristics under optimal parameter combination;
               must contain the columns 'yearmonth', 'a_ai' and 'a_ot'
    :param path: (str) path to save output plots
    :param plots_colors: (dict) dictionary containing plot colors by variable name
                         (keys 'AI' and 'OldTech' are used)
    :param font: (int) font size of the plot text
    :return: None; the plot is saved as 'Figure_6b.jpg' under `path` and shown
    """
    # Average only the plotted columns: averaging every column fails on non-numeric
    # columns in recent pandas versions and does needless work otherwise.
    df_mean = df.groupby('yearmonth')[['a_ai', 'a_ot']].mean().reset_index()
    # Parse the dates before sorting so the lines are drawn in chronological order
    # even when 'yearmonth' is stored as text.
    df_mean['yearmonth'] = pd.to_datetime(df_mean['yearmonth'])
    df_mean = df_mean.sort_values('yearmonth')
    # Productivity is displayed in millions.
    df_mean['a_ai'] = df_mean['a_ai'] / 1000000
    df_mean['a_ot'] = df_mean['a_ot'] / 1000000
    df_mean = df_mean.rename(columns={'a_ai': 'A_AI', 'a_ot': 'A_OT'})

    # Create exactly one figure; a separate plt.figure() call would leave an empty
    # extra figure behind and its figsize would not apply to these axes.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(df_mean['yearmonth'], df_mean['A_AI'], color=plots_colors['AI'], label='A_AI')
    ax.plot(df_mean['yearmonth'], df_mean['A_OT'], color=plots_colors['OldTech'], label='A_OT')
    ax.set_xlabel('Date', fontsize=font)
    ax.set_title('AI and OT productivity', fontsize=font)
    ax.set_ylabel('Productivity (M)', fontsize=font)
    # One tick per year, labelled with the four-digit year
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    ax.tick_params(axis='both', labelsize=font)
    ax.legend(fontsize=font)
    fig.savefig(os.path.join(path, 'Figure_6b.jpg'), bbox_inches='tight')
    plt.show()


def plot_value_data(df, path, plots_colors, font=18):
    """
    Plot the value of data, summed over all rows of each month, over time
    :param df: the characteristics under optimal parameter combination; must contain the
               columns 'yearmonth' and 'data_value'
    :param path: (str) path to save output plots
    :param plots_colors: (dict) dictionary containing plot colors by variable name
                         (key 'Data_value' is used)
    :param font: (int) font size of the plot text
    :return: the summed 'data_value' of the latest month; the plot is saved as
             'Figure_6a.jpg' under `path` and shown
    """
    # Sum only the plotted column rather than every column of the frame.
    monthly = df.groupby('yearmonth')[['data_value']].sum().reset_index()
    # Parse the dates BEFORE sorting so that "last row" really is the latest month,
    # even when 'yearmonth' is stored as text.
    monthly['yearmonth'] = pd.to_datetime(monthly['yearmonth'])
    monthly = monthly.sort_values('yearmonth')

    # Create exactly one figure; a separate plt.figure() call would leave an empty
    # extra figure behind and its figsize would not apply to these axes.
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(monthly['yearmonth'], monthly['data_value'],
            color=plots_colors['Data_value'], lw=2, alpha=0.6)
    ax.set_ylabel('Value of data', fontsize=font)
    ax.set_xlabel('Date', fontsize=font)
    ax.set_title('Cumulative value of data over time', fontsize=font)
    ax.tick_params(axis='both', labelsize=font)
    # One tick per year, labelled with the four-digit year
    ax.xaxis.set_major_locator(mdates.YearLocator())
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
    fig.savefig(os.path.join(path, 'Figure_6a.jpg'), bbox_inches='tight')
    plt.show()
    return monthly['data_value'].iloc[-1]


def df_unpack(df):
    """
    Split the input DataFrame into separate, equally long columns to be used as inputs in the
    optimization function. Rows are first ordered by ('Employer_ID', 'yearmonth') and the
    index is reset, so all returned columns share the same 0..n-1 index.
    :param df: (DataFrame) input DataFrame
    :return: tuple of pandas Series, in this order:
             (r, w_ai, l_ai, w_ot, l_ot, w_dm, l_dm, yearmonth, emp)
    """
    short_names = {
        'GS1M': 'r',
        'w_AI': 'w_ai', 'w_OT': 'w_ot', 'w_DM': 'w_dm',
        'L_AI': 'l_ai', 'L_OldTech': 'l_ot', 'L_DataMgmt': 'l_dm',
        'Employer_ID': 'emp',
    }
    ordered = df.sort_values(['Employer_ID', 'yearmonth']).reset_index()
    ordered = ordered.rename(columns=short_names)

    # Columns are fetched via attribute access, one at a time, in a fixed order.
    fetch_order = ('r', 'emp', 'w_ai', 'l_ai', 'w_ot', 'l_ot', 'w_dm', 'l_dm', 'yearmonth')
    cols = {name: getattr(ordered, name) for name in fetch_order}

    return (cols['r'],
            cols['w_ai'], cols['l_ai'],
            cols['w_ot'], cols['l_ot'],
            cols['w_dm'], cols['l_dm'],
            cols['yearmonth'], cols['emp'])
